{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Relay Sequential pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tvm\n",
    "from tvm import relay\n",
    "from tvm.relay.testing import run_infer_type\n",
    "\n",
    "\n",
    "def check_func(func, ref_func):\n",
    "    \"\"\"Run type inference on both functions and assert they are structurally equal.\"\"\"\n",
    "    func = run_infer_type(func)\n",
    "    ref_func = run_infer_type(ref_func)\n",
    "    assert tvm.ir.structural_equal(func, ref_func)\n",
    "\n",
    "\n",
    "def extract_var_func(mod, name):\n",
    "    \"\"\"Return ``(global_var, function)`` for the global named ``name`` in ``mod``.\"\"\"\n",
    "    var = mod.get_global_var(name)\n",
    "    func = mod[var]\n",
    "    return var, func\n",
    "\n",
    "\n",
    "def get_rand(shape, dtype=\"float32\"):\n",
    "    \"\"\"Return a ``tvm.nd.array`` of ``np.random.rand`` samples cast to ``dtype``.\"\"\"\n",
    "    return tvm.nd.array(np.random.rand(*shape).astype(dtype))\n",
    "\n",
    "\n",
    "def get_ref_log():\n",
    "    \"\"\"Build the reference function ``log(x + x)``.\n",
    "\n",
    "    The parameter is created here, like ``get_ref_abs`` does, instead of being\n",
    "    read from a notebook global that is only defined in a later cell. It uses\n",
    "    the same ``(10,)`` float32 tensor type as the module-setup cell.\n",
    "    \"\"\"\n",
    "    x = relay.var(\"x\", relay.TensorType((10,), \"float32\"))\n",
    "    ref_log = relay.Function([x], relay.log(relay.add(x, x)))\n",
    "    return ref_log\n",
    "\n",
    "\n",
    "def get_ref_sub():\n",
    "    \"\"\"Build the reference function ``(x + x) - (y + y)``.\n",
    "\n",
    "    Both parameters are created here, so the helper does not depend on\n",
    "    notebook globals that are only defined in a later cell. They use the same\n",
    "    ``(10,)`` float32 tensor type as the module-setup cell.\n",
    "    \"\"\"\n",
    "    tp = relay.TensorType((10,), \"float32\")\n",
    "    x = relay.var(\"x\", tp)\n",
    "    y = relay.var(\"y\", tp)\n",
    "    ref_sub = relay.Function([x, y], relay.subtract(relay.add(x, x), relay.add(y, y)))\n",
    "    return ref_sub\n",
    "\n",
    "\n",
    "def get_ref_abs():\n",
    "    \"\"\"Build the reference function ``abs(a + a)`` over a ``(5, 10)`` float32 tensor.\"\"\"\n",
    "    shape = (5, 10)\n",
    "    tp = relay.TensorType(shape, \"float32\")\n",
    "    a = relay.var(\"a\", tp)\n",
    "    ref_abs = relay.Function([a], relay.abs(relay.add(a, a)))\n",
    "    return ref_abs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tensor type shared by every variable created in this cell.\n",
    "shape = (10,)\n",
    "dtype = \"float32\"\n",
    "tp = relay.TensorType(shape, dtype)\n",
    "\n",
    "# x and y feed the subtract function, z feeds the log function.\n",
    "x, y, z = (relay.var(name, tp) for name in (\"x\", \"y\", \"z\"))\n",
    "\n",
    "# Global symbols under which the two functions are stored in the module.\n",
    "v_sub = relay.GlobalVar(\"mySub\")\n",
    "v_log = relay.GlobalVar(\"myLog\")\n",
    "\n",
    "sub = relay.Function([x, y], relay.subtract(x, y))\n",
    "log = relay.Function([z], relay.log(z))\n",
    "\n",
    "mod = tvm.IRModule({v_sub: sub, v_log: log})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.helper import OptTester\n",
    "\n",
    "# Build an OptTester from the module; the passes registered in the following\n",
    "# cells delegate to its `transform` method.\n",
    "opt_tester = OptTester(mod)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tvm.transform.module_pass(opt_level=1)\n",
    "def mod_transform(expr, ctx):\n",
    "    \"\"\"Module pass (opt_level=1) that forwards its arguments to ``opt_tester.transform``.\"\"\"\n",
    "    transformed = opt_tester.transform(expr, ctx)\n",
    "    return transformed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register a function pass.\n",
    "@relay.transform.function_pass(opt_level=1)\n",
    "def func_transform(expr, mod, ctx):\n",
    "    \"\"\"Function pass (opt_level=1) that forwards ``expr`` and ``ctx`` to ``opt_tester.transform``.\n",
    "\n",
    "    ``mod`` is accepted to match the pass signature but is not used.\n",
    "    \"\"\"\n",
    "    transformed = opt_tester.transform(expr, ctx)\n",
    "    return transformed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sequential pass: bundle the module pass and the function pass.\n",
    "opt_level = 2\n",
    "pass_name = \"sequential\"\n",
    "passes = [mod_transform, func_transform]\n",
    "sequential = tvm.transform.Sequential(opt_level=opt_level, passes=passes)\n",
    "\n",
    "# No name is passed above, so the pass info is expected to report\n",
    "# \"sequential\" together with the requested opt level.\n",
    "pass_info = sequential.info\n",
    "assert pass_info.name == pass_name\n",
    "assert pass_info.opt_level == opt_level"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试 seq pass\n",
    "\n",
    "空白 pass："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A Sequential built from an empty pass list is applied to the module.\n",
    "passes = []\n",
    "sequential = tvm.transform.Sequential(opt_level=1, passes=passes)\n",
    "ret_mod = sequential(mod)\n",
    "# The subtract function in the result must still match the original `sub`.\n",
    "mod_func = ret_mod[v_sub]\n",
    "check_func(sub, mod_func)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "模块级 pass："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sequential containing only the module-level pass.\n",
    "passes = [mod_transform]\n",
    "sequential = tvm.transform.Sequential(opt_level=1, passes=passes)\n",
    "# Run it inside a PassContext that lists `mod_transform` as a required pass.\n",
    "with tvm.transform.PassContext(required_pass=[\"mod_transform\"]):\n",
    "    ret_mod = sequential(mod)\n",
    "# Check the subtract function: it must be structurally equal to the original `sub`.\n",
    "sub_var, new_sub = extract_var_func(ret_mod, v_sub.name_hint)\n",
    "check_func(new_sub, sub)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 带作用域的 pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "shape = (1, 2, 3)\n",
    "c_data = np.array(shape).astype(\"float32\")\n",
    "tp = relay.TensorType(shape, \"float32\")\n",
    "\n",
    "\n",
    "def before():\n",
    "    \"\"\"Build the unoptimized function: constant arithmetic plus two identical adds.\"\"\"\n",
    "    const_c = relay.const(c_data)\n",
    "    x = relay.var(\"x\", tp)\n",
    "    doubled = relay.add(const_c, const_c)\n",
    "    scaled = relay.multiply(doubled, relay.const(2, \"float32\"))\n",
    "    shifted = relay.add(x, scaled)\n",
    "    # Two separate but identical `shifted + c` expressions.\n",
    "    lhs = relay.add(shifted, const_c)\n",
    "    rhs = relay.add(shifted, const_c)\n",
    "    return relay.Function([x], relay.add(lhs, rhs))\n",
    "\n",
    "\n",
    "def expected():\n",
    "    \"\"\"Build the reference function: constants pre-folded in NumPy, one shared add.\"\"\"\n",
    "    x = relay.var(\"x\", tp)\n",
    "    c_folded = (c_data + c_data) * 2\n",
    "    shifted = relay.add(x, relay.const(c_folded))\n",
    "    shared = relay.add(shifted, relay.const(c_data))\n",
    "    return relay.Function([x], relay.add(shared, shared))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/media/pc/data/4tb/lxw/libs/anaconda3/envs/py38/lib/python3.8/site-packages/tvm/driver/build_module.py:267: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "seq = tvm.transform.Sequential(\n",
    "    [\n",
    "        relay.transform.InferType(),\n",
    "        relay.transform.FoldConstant(),\n",
    "        relay.transform.EliminateCommonSubexpr(),\n",
    "        relay.transform.AlterOpLayout(),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Use dedicated names instead of rebinding `mod`: the earlier test cells call\n",
    "# `sequential(mod)` on the module built in the setup cell, so they could not be\n",
    "# re-run after this cell if `mod` were overwritten here.\n",
    "scoped_mod = tvm.IRModule({\"main\": before()})\n",
    "with tvm.transform.PassContext(opt_level=3):\n",
    "    with tvm.target.Target(\"llvm\"):\n",
    "        scoped_opt_mod = seq(scoped_mod)\n",
    "\n",
    "# The optimized `main` must match the type-inferred reference function.\n",
    "zz = scoped_opt_mod[\"main\"]\n",
    "zexpected = run_infer_type(expected())\n",
    "assert tvm.ir.structural_equal(zz, zexpected)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 嵌套型 pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def before_nested():\n",
    "    \"\"\"Build a module with conv2d followed by two reshapes and a reverse_reshape.\"\"\"\n",
    "    x = relay.var(\"x\", shape=(1, 16, 16, 16), dtype=\"float32\")\n",
    "    w = relay.var(\"w\", shape=(32, 16, 3, 3), dtype=\"float32\")\n",
    "    y = relay.nn.conv2d(x, w, padding=(1, 1))\n",
    "    y = relay.reshape(y, newshape=(1, 16, -1))\n",
    "    y = relay.reshape(y, newshape=(4, 8, -1, 16))\n",
    "    y = relay.reverse_reshape(y, newshape=(32, 0, -1))\n",
    "    return tvm.IRModule.from_expr(y)\n",
    "\n",
    "\n",
    "def expected_nested():\n",
    "    \"\"\"Build the reference module: conv2d followed by a single reshape to (32, 16, 16).\"\"\"\n",
    "    x = relay.var(\"x\", shape=(1, 16, 16, 16), dtype=\"float32\")\n",
    "    w = relay.var(\"w\", shape=(32, 16, 3, 3), dtype=\"float32\")\n",
    "    y = relay.nn.conv2d(x, w, padding=(1, 1))\n",
    "    y = relay.reshape(y, newshape=(32, 16, 16))\n",
    "    return tvm.IRModule.from_expr(y)\n",
    "\n",
    "\n",
    "# Distinct names keep this cell from overwriting `before`/`expected` (still used\n",
    "# by the scoped-pass cell) and `z` (the Relay variable from the setup cell), so\n",
    "# the earlier cells remain re-runnable after this one.\n",
    "nested_mod = before_nested()\n",
    "passes = [\n",
    "    tvm.transform.Sequential([relay.transform.SimplifyExpr()]),\n",
    "]\n",
    "with tvm.transform.PassContext(opt_level=1):\n",
    "    simplified_mod = tvm.transform.Sequential(passes)(nested_mod)\n",
    "\n",
    "expected_mod = relay.transform.InferType()(expected_nested())\n",
    "assert tvm.ir.structural_equal(simplified_mod, expected_mod)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('py38': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "28558e8daad512806f5c536a1a04c119185f99f65b79002708a12162d02a79c7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
